DeepEM: A Deep Neural Network for DEM Inversion

by Paul Wright$^{1}$, Mark Cheung, Rajat Thomas, Richard Galvez, Alexandre Szenicer, Meng Jin, Andres Munoz-Jaramillo, and David Fouhey

$^{1}$ University of Glasgow; email: paul@pauljwright.co.uk

The intensity observed through optically-thin SDO/AIA filters (94 Å, 131 Å, 171 Å, 193 Å, 211 Å, 335 Å) can be related to the temperature distribution of the solar corona (the differential emission measure; DEM) as

\begin{equation} g_{i} = \int_{T} K_{i}(T) \xi(T) dT \, . \end{equation}

In this equation, $g_{i}$ is the DN s$^{-1}$ px$^{-1}$ value in the $i$th SDO/AIA channel, $K_{i}(T)$ is the corresponding temperature response function, and the DEM, $\xi(T)$, is in units of cm$^{-5}$ K$^{-1}$. The matrix formulation of this integral equation can be represented in the form $\vec{g} = {\bf K}\vec{\xi}$; however, this is an ill-posed inverse problem, and any attempt to directly recover $\vec{\xi}$ leads to significant noise amplification.
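To see why direct inversion amplifies noise, consider a toy ill-conditioned square system (a 6x6 Hilbert matrix standing in for ${\bf K}$; this is purely illustrative and is not the real AIA response matrix): a tiny perturbation of $\vec{g}$ produces a disproportionately large change in the recovered $\vec{\xi}$.

```python
import numpy as np

# Toy ill-conditioned 6x6 kernel (Hilbert matrix), a stand-in for K
n = 6
K = np.array([[1.0 / (i + j + 1) for j in range(n)] for i in range(n)])
print("condition number of K: {:.2e}".format(np.linalg.cond(K)))

xi_true = np.ones(n)   # "true" emission measure vector
g = K @ xi_true        # noise-free observations

# Add a tiny amount of noise to g, then invert directly
rng = np.random.default_rng(0)
noise = 1e-8 * rng.standard_normal(n)
xi_noisy = np.linalg.solve(K, g + noise)

rel_noise = np.linalg.norm(noise) / np.linalg.norm(g)
rel_err = np.linalg.norm(xi_noisy - xi_true) / np.linalg.norm(xi_true)
print("relative noise in g:  {:.1e}".format(rel_noise))
print("relative error in xi: {:.1e}".format(rel_err))
```

The relative error in the recovered solution exceeds the relative noise in the data by several orders of magnitude, which is why regularised or learned inversion schemes are needed.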

There are numerous methods to tackle mathematical problems of this kind, and an increasing number of methods in the literature for recovering the differential emission measure, including techniques based on Tikhonov regularisation (Hannah & Kontar 2012) and on the concept of sparsity (Cheung et al. 2015). In the following notebook, we will demonstrate how a 1x1 2D convolutional neural network allows for a significant improvement in computational speed for DEM inversion, with similar fidelity to the method used for training (Basis Pursuit). Additionally, this method, DeepEM, provides solutions with emission measure values >0 in every temperature bin.



In this chapter we will introduce a Deep Learning approach for DEM inversion. For this notebook, DeepEM is trained on one set of SDO/AIA observations (six optically-thin channels; $6 \times N \times N$) and DEM solutions (18 temperature bins from log$_{10}$T = 5.5 - 7.2; $18 \times N \times N$; Cheung et al. 2015) at a resolution of $512 \times 512$ ($N = 512$), using a $1 \times 1$ 2D Convolutional Neural Network with a single hidden layer.

The DeepEM method presented here uses every DEM solution for training, with no regard to the quality or existence of the solution. As will be demonstrated, when this method is trained on a single set of images and DEM solutions, the DeepEM solutions have a similar fidelity to Sparse Inversion (with a significantly increased computation speed); additionally, the DeepEM solutions are positive at every pixel and exhibit reduced noise.

In [1]:
import os
import json
import time
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from scipy.io import readsav
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torch.autograd import Variable
from torch.utils.data import DataLoader

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
In [2]:
def em_scale(y):
    # Scale the emission measure (cm^-5) to order-unity values for training
    return np.sqrt(y/1e25)

def em_unscale(y):
    # Invert em_scale
    return 1e25*(y*y)

def img_scale(x):
    # Clip non-physical (negative) intensities and square-root scale the images
    x2 = x.copy()  # copy to avoid modifying the input array in place
    x2[x2 <= 0.0] = 0.0
    return np.sqrt(x2)

def img_unscale(x):
    # Invert img_scale
    return x*x
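As a quick sanity check, the scale/unscale pairs should round-trip, and img_scale should clip negative intensities to zero. The helpers are restated below (with a copy-safe img_scale, so the input array is left untouched) to make the snippet standalone:

```python
import numpy as np

def em_scale(y):
    return np.sqrt(y / 1e25)

def em_unscale(y):
    return 1e25 * (y * y)

def img_scale(x):
    x2 = x.copy()        # work on a copy so the input is untouched
    x2[x2 <= 0.0] = 0.0  # clip non-physical negative intensities
    return np.sqrt(x2)

def img_unscale(x):
    return x * x

em = np.array([1e25, 4e25])
assert np.allclose(em_unscale(em_scale(em)), em)  # round trip

img = np.array([-1.0, 0.0, 9.0])
assert np.allclose(img_scale(img), [0.0, 0.0, 3.0])  # negatives clipped
assert np.allclose(img_unscale(img_scale(img)), [0.0, 0.0, 9.0])
```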

Step 1: Obtain Data and Sparse Inversion Solutions for Training

We first load the SDO/AIA images and Basis Pursuit DEM maps.

N.B. While this simplified version of DeepEM has been trained on DEM maps from Basis Pursuit (Cheung et al. 2015), we actively encourage readers to try their favourite method for DEM inversion!

In [3]:
aia_files = ['AIA_DEM_2011-01-27','AIA_DEM_2011-02-22','AIA_DEM_2011-03-20']
em_cube_files = aia_files

for k, (afile, emfile) in enumerate(zip(aia_files, em_cube_files)):
    afile_name = os.path.join('./DeepEM_Data/', afile + '.aia.npy')
    emfile_name = os.path.join('./DeepEM_Data/', emfile + '.emcube.npy')
    if k == 0:
        X = np.load(afile_name)
        y = np.load(emfile_name)
 
        X = np.zeros((len(aia_files), X.shape[0], X.shape[1], X.shape[2]))
        y = np.zeros((len(em_cube_files), y.shape[0], y.shape[1], y.shape[2]))
        
        nlgT = y.shape[1]  # number of temperature bins (18); axis 0 is the file index
        lgtaxis = np.arange(nlgT)*0.1 + 5.5  # log10 T bin centres, 5.5 ... 7.2
        
    X[k] = np.load(afile_name)
    y[k] = np.load(emfile_name) 
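The loop above uses the first file only to discover the array shapes, then re-allocates X and y to hold all three dates. A miniature version with synthetic in-memory arrays (standing in for the .npy files, and using a tiny N purely for illustration) shows the resulting shapes:

```python
import numpy as np

n_files, N = 3, 8  # tiny N instead of 512, purely for illustration
fake_aia = [np.random.rand(6, N, N) for _ in range(n_files)]   # stand-ins for the .aia.npy cubes
fake_em  = [np.random.rand(18, N, N) for _ in range(n_files)]  # stand-ins for the .emcube.npy cubes

for k in range(n_files):
    if k == 0:
        # First pass: use the first cube's shape to allocate the full arrays
        X_demo = np.zeros((n_files,) + fake_aia[0].shape)
        y_demo = np.zeros((n_files,) + fake_em[0].shape)
        nlgT_demo = y_demo.shape[1]                    # 18 temperature bins
        lgtaxis_demo = np.arange(nlgT_demo)*0.1 + 5.5  # log10 T = 5.5 ... 7.2
    X_demo[k] = fake_aia[k]
    y_demo[k] = fake_em[k]

print(X_demo.shape, y_demo.shape)  # (3, 6, 8, 8) (3, 18, 8, 8)
```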

Step 2: Define the Model

We first define the model as a 2D Convolutional Neural Network (CNN) with a kernel size of 1x1. The model accepts a data cube of $6 \times N \times N$ (SDO/AIA data) and returns a data cube of $18 \times N \times N$ (DEM), which, when trained, will transform the input (each pixel of the 6 SDO/AIA channels; $6 \times 1 \times 1$) to the output (the DEM at each pixel; $18 \times 1 \times 1$).

In [4]:
model = nn.Sequential(
    nn.Conv2d(6, 300, kernel_size=1),
    nn.LeakyReLU(), #Activation function
    nn.Conv2d(300, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 18, kernel_size=1)).cuda() #Loading model on to gpu
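For a sense of model size: a 1x1 Conv2d with $c_{in}$ input and $c_{out}$ output channels has $c_{in} c_{out}$ weights plus $c_{out}$ biases, so the three layers above contain only ~10^5 trainable parameters, one reason inference is fast. A quick arithmetic check:

```python
def conv1x1_params(c_in, c_out):
    # A 1x1 Conv2d has c_in * c_out weights plus c_out biases
    return c_in * c_out + c_out

# The three layers of the model defined above
layers = [(6, 300), (300, 300), (300, 18)]
total = sum(conv1x1_params(c_in, c_out) for c_in, c_out in layers)
print(total)  # 97818
```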

Step 3: Train the Model

For training our CNN we select one SDO/AIA data cube ($6\times512\times512$) and the corresponding Sparse Inversion DEM output ($18\times512\times512$). In the case presented here, we train the CNN on an image of the Sun obtained on 27-01-2011, validate on an image of the Sun obtained one synodic rotation later (+26 days; 22-02-2011), and finally test on an image another 26 days later (20-03-2011).

In [5]:
X = img_scale(X)
y = em_scale(y)

X_train = X[0:1] 
y_train = y[0:1] 

X_val = X[1:2] 
y_val = y[1:2] 

X_test = X[2:3] 
y_test = y[2:3]
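Note the X[0:1] slicing: unlike X[0], it keeps the leading (sample) axis, so each split has shape $1 \times 6 \times 512 \times 512$ (or $1 \times 18 \times 512 \times 512$ for y), which is the batched layout the DataLoader expects. A quick illustration with a dummy array:

```python
import numpy as np

arr = np.zeros((3, 6, 512, 512))
print(arr[0].shape)    # (6, 512, 512)  -- integer index drops the leading axis
print(arr[0:1].shape)  # (1, 6, 512, 512) -- slice keeps the leading axis
```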

Plotting SDO/AIA Observations ${\it vs.}$ Basis Pursuit DEM bins

For the test data set, the SDO/AIA images for 171 Å, 211 Å, and 94 Å, and the corresponding DEM bins near the peak sensitivity of these relatively isothermal channels (log$_{10}$T ~ 5.9, 6.3, and 7.0), are shown in Figure 1. Figure 1 shows the set of SDO/AIA images (171 Å, 211 Å, and 94 Å [left to right]) with the corresponding DEM maps for temperature bins near the peak sensitivity of each SDO/AIA channel. Furthermore, it is clear from the DEM maps that a number of pixels are zero. These pixels are primarily located off-disk, but a number of on-disk pixels show this behaviour as well.

In [6]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(y_test[0,8,:,:],vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(y_test[0,4,:,:],vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(y_test[0,15,:,:],vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)

Figure 1: Left to right: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top), with the corresponding DEM bins (chosen at the peak sensitivity of each SDO/AIA channel) shown below. In the DEM bins (bottom) it is clear that some pixels have solutions of DEM = 0, seen as dark regions/clusters of pixels both on and off the disk.


To implement training and testing of our model, we first define a DEMdata class, and then define functions for training and validation/testing: train_model and valtest_model.

N.B. It is not necessary to train the model; if required, the pre-trained model can be loaded onto the CPU as follows:

model = nn.Sequential(
    nn.Conv2d(6, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 18, kernel_size=1))

dem_model_file = 'DeepEM_CNN_HelioML.pth'
model.load_state_dict(torch.load(dem_model_file, map_location='cpu'))

Once you have loaded the model, skip to Step 4: Testing the Model.

In [7]:
class DEMdata(torch.utils.data.Dataset):  # a Dataset, so it can be consumed by DataLoader
    def __init__(self, xtrain, ytrain, xtest, ytest, xval, yval, split='train'):
        
        if split == 'train':
            self.x = xtrain
            self.y = ytrain
        if split == 'val':
            self.x = xval
            self.y = yval
        if split == 'test':
            self.x = xtest
            self.y = ytest
            
    def __getitem__(self, index):
        return torch.from_numpy(self.x[index]).type(torch.FloatTensor), torch.from_numpy(self.y[index]).type(torch.FloatTensor)

    def __len__(self):
        return self.x.shape[0]
In [8]:
def train_model(dem_loader, criterion, optimizer, epochs=500):
    model.train()
    train_loss_all_batches = []
    train_loss_epoch = []
    train_val = []
    for k in range(epochs):
        model.train()  # valtest_model switches to eval mode, so re-enable training each epoch
        count_ = 0
        avg_loss = 0
        # =================== progress indicator ==============
        if k % ((epochs + 1) // 4) == 0:
            print('[{0}]: {1:.1f}% complete: '.format(k, k / epochs * 100))
        # =====================================================
        for img, dem in dem_loader:
            count_ += 1
            optimizer.zero_grad()
            # =================== forward =====================
            img = img.cuda()
            dem = dem.cuda()

            output = model(img) 
            loss = criterion(output, dem)

            loss.backward()
            optimizer.step()
            
            train_loss_all_batches.append(loss.item())
            avg_loss += loss.item()
        # =================== Validation ===================
        dem_data_val = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='val')
        dem_loader_val = DataLoader(dem_data_val, batch_size=1)
        val_loss, dummy, dem_pred_val, dem_in_test_val = valtest_model(dem_loader_val, criterion)
        
        train_loss_epoch.append(avg_loss/count_)
        train_val.append(val_loss)
        
        print('Epoch: ', k, 'trn_loss: ', avg_loss/count_, 'val_loss: ', train_val[k])
            
    torch.save(model.state_dict(), 'DeepEM_CNN_HelioML.pth')
    return train_loss_epoch, train_val

def valtest_model(dem_loader, criterion):

    model.eval()
    
    val_loss = 0
    count = 0
    test_loss = []
    for img, dem in dem_loader:
        count += 1
        # =================== forward =====================
        img = img.cuda()
        dem = dem.cuda()
        
        output = model(img)
        loss = criterion(output, dem)
        test_loss.append(loss.item())
        val_loss += loss.item()
        
    return val_loss/count, test_loss, output, dem

We choose the Adam optimiser with a learning rate of 1e-4 and weight_decay set to 1e-9. We use the Mean Squared Error (MSE) between the Sparse Inversion DEM map and the DeepEM map as our loss function.

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-9)
criterion = nn.MSELoss().cuda()
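nn.MSELoss simply averages the squared difference over every element of the predicted and target cubes; in numpy terms (a standalone illustration with tiny stand-in arrays, not the training code):

```python
import numpy as np

pred = np.array([[0.5, 1.0], [2.0, 0.0]])    # stand-in for a DeepEM output
target = np.array([[1.0, 1.0], [1.0, 1.0]])  # stand-in for a Sparse Inversion DEM
mse = np.mean((pred - target) ** 2)
print(mse)  # (0.25 + 0 + 1 + 1) / 4 = 0.5625
```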

Using the classes and functions defined above, dem_data will return the training data, which will be loaded by the DataLoader with batch_size=1 (one 512 x 512 image per batch). For each epoch, train_loss and valdn_loss will be returned by train_model.

In [10]:
dem_data = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='train')
dem_loader = DataLoader(dem_data, batch_size=1)

t0=time.time() #Timing how long it takes to train the model
train_loss, valdn_loss = train_model(dem_loader, criterion, optimizer, epochs=500)
ttime = "Training time = {0} seconds".format(time.time()-t0)
print(ttime)
[0]: 0.0% complete: 
Epoch:  0 trn_loss:  2.211900472640991 val_loss:  2.3412883281707764
Epoch:  1 trn_loss:  2.0418665409088135 val_loss:  2.1650094985961914
Epoch:  2 trn_loss:  1.8821855783462524 val_loss:  1.9994359016418457
...
[125]: 25.0% complete: 
Epoch:  125 trn_loss:  0.34305059909820557 val_loss:  0.42208409309387207
...
[250]: 50.0% complete: 
Epoch:  250 trn_loss:  0.2327679991722107 val_loss:  0.28403207659721375
...
Epoch:  302 trn_loss:  0.21414814889431 val_loss:  0.26315173506736755
Epoch:  303 trn_loss:  0.21383033692836761 val_loss:  0.2628442049026489
Epoch:  304 trn_loss:  0.21351349353790283 val_loss:  0.2625076174736023
Epoch:  305 trn_loss:  0.21319808065891266 val_loss:  0.2621603012084961
Epoch:  306 trn_loss:  0.21288315951824188 val_loss:  0.2618192434310913
Epoch:  307 trn_loss:  0.21256844699382782 val_loss:  0.2614860534667969
Epoch:  308 trn_loss:  0.2122551053762436 val_loss:  0.261149525642395
Epoch:  309 trn_loss:  0.2119426131248474 val_loss:  0.2608077824115753
Epoch:  310 trn_loss:  0.2116304188966751 val_loss:  0.260470986366272
Epoch:  311 trn_loss:  0.21131953597068787 val_loss:  0.26014530658721924
Epoch:  312 trn_loss:  0.21100953221321106 val_loss:  0.2598287761211395
Epoch:  313 trn_loss:  0.2106999307870865 val_loss:  0.2595067024230957
Epoch:  314 trn_loss:  0.21039125323295593 val_loss:  0.259169340133667
Epoch:  315 trn_loss:  0.21008294820785522 val_loss:  0.2588276267051697
Epoch:  316 trn_loss:  0.20977510511875153 val_loss:  0.25850459933280945
Epoch:  317 trn_loss:  0.20946796238422394 val_loss:  0.2582058310508728
Epoch:  318 trn_loss:  0.20916146039962769 val_loss:  0.2579147219657898
Epoch:  319 trn_loss:  0.20885591208934784 val_loss:  0.25760617852211
Epoch:  320 trn_loss:  0.20855110883712769 val_loss:  0.2572757303714752
Epoch:  321 trn_loss:  0.20824721455574036 val_loss:  0.25694411993026733
Epoch:  322 trn_loss:  0.2079445719718933 val_loss:  0.2566297650337219
Epoch:  323 trn_loss:  0.2076428234577179 val_loss:  0.2563343048095703
Epoch:  324 trn_loss:  0.20734184980392456 val_loss:  0.25603586435317993
Epoch:  325 trn_loss:  0.20704220235347748 val_loss:  0.2557259798049927
Epoch:  326 trn_loss:  0.20674359798431396 val_loss:  0.2554151713848114
Epoch:  327 trn_loss:  0.20644627511501312 val_loss:  0.25511470437049866
Epoch:  328 trn_loss:  0.20615015923976898 val_loss:  0.25481754541397095
Epoch:  329 trn_loss:  0.20585542917251587 val_loss:  0.254513144493103
Epoch:  330 trn_loss:  0.2055618166923523 val_loss:  0.2542076110839844
Epoch:  331 trn_loss:  0.20526950061321259 val_loss:  0.25391685962677
Epoch:  332 trn_loss:  0.20497849583625793 val_loss:  0.25364789366722107
Epoch:  333 trn_loss:  0.20468848943710327 val_loss:  0.2533756196498871
Epoch:  334 trn_loss:  0.20439974963665009 val_loss:  0.25308457016944885
Epoch:  335 trn_loss:  0.20411206781864166 val_loss:  0.2527826428413391
Epoch:  336 trn_loss:  0.20382563769817352 val_loss:  0.2524893581867218
Epoch:  337 trn_loss:  0.2035403549671173 val_loss:  0.2522111237049103
Epoch:  338 trn_loss:  0.20325632393360138 val_loss:  0.251936137676239
Epoch:  339 trn_loss:  0.20297357439994812 val_loss:  0.2516572177410126
Epoch:  340 trn_loss:  0.2026919424533844 val_loss:  0.2513802945613861
Epoch:  341 trn_loss:  0.20241157710552216 val_loss:  0.2511087656021118
Epoch:  342 trn_loss:  0.202132448554039 val_loss:  0.2508338391780853
Epoch:  343 trn_loss:  0.20185457170009613 val_loss:  0.25054723024368286
Epoch:  344 trn_loss:  0.20157787203788757 val_loss:  0.2502579391002655
Epoch:  345 trn_loss:  0.2013024538755417 val_loss:  0.24998235702514648
Epoch:  346 trn_loss:  0.2010282725095749 val_loss:  0.24972325563430786
Epoch:  347 trn_loss:  0.2007552683353424 val_loss:  0.24946658313274384
Epoch:  348 trn_loss:  0.20048342645168304 val_loss:  0.24920010566711426
Epoch:  349 trn_loss:  0.2002127766609192 val_loss:  0.2489280104637146
Epoch:  350 trn_loss:  0.19994327425956726 val_loss:  0.24866002798080444
Epoch:  351 trn_loss:  0.19967490434646606 val_loss:  0.24839621782302856
Epoch:  352 trn_loss:  0.19940762221813202 val_loss:  0.24813175201416016
Epoch:  353 trn_loss:  0.19914144277572632 val_loss:  0.24786800146102905
Epoch:  354 trn_loss:  0.19887645542621613 val_loss:  0.247611865401268
Epoch:  355 trn_loss:  0.19861255586147308 val_loss:  0.24736247956752777
Epoch:  356 trn_loss:  0.19834984838962555 val_loss:  0.24710924923419952
Epoch:  357 trn_loss:  0.19808828830718994 val_loss:  0.2468482106924057
Epoch:  358 trn_loss:  0.19782783091068268 val_loss:  0.24658840894699097
Epoch:  359 trn_loss:  0.19756847620010376 val_loss:  0.24633683264255524
Epoch:  360 trn_loss:  0.197310209274292 val_loss:  0.24608983099460602
Epoch:  361 trn_loss:  0.19705307483673096 val_loss:  0.245840385556221
Epoch:  362 trn_loss:  0.19679692387580872 val_loss:  0.24558967351913452
Epoch:  363 trn_loss:  0.19654175639152527 val_loss:  0.24534188210964203
Epoch:  364 trn_loss:  0.19628755748271942 val_loss:  0.2450958639383316
Epoch:  365 trn_loss:  0.196034237742424 val_loss:  0.2448493093252182
Epoch:  366 trn_loss:  0.19578175246715546 val_loss:  0.24460111558437347
Epoch:  367 trn_loss:  0.19553016126155853 val_loss:  0.24435578286647797
Epoch:  368 trn_loss:  0.19527946412563324 val_loss:  0.24411651492118835
Epoch:  369 trn_loss:  0.19502955675125122 val_loss:  0.2438783049583435
Epoch:  370 trn_loss:  0.19478055834770203 val_loss:  0.2436365783214569
Epoch:  371 trn_loss:  0.19453251361846924 val_loss:  0.24339379370212555
Epoch:  372 trn_loss:  0.1942853182554245 val_loss:  0.2431551069021225
Epoch:  373 trn_loss:  0.19403891265392303 val_loss:  0.24291695654392242
Epoch:  374 trn_loss:  0.1937933713197708 val_loss:  0.24267594516277313
[375]: 75.0% complete: 
Epoch:  375 trn_loss:  0.19354864954948425 val_loss:  0.2424352765083313
Epoch:  376 trn_loss:  0.19330467283725739 val_loss:  0.24220016598701477
Epoch:  377 trn_loss:  0.19306141138076782 val_loss:  0.24196521937847137
Epoch:  378 trn_loss:  0.19281889498233795 val_loss:  0.24173098802566528
Epoch:  379 trn_loss:  0.19257688522338867 val_loss:  0.24149714410305023
Epoch:  380 trn_loss:  0.19233542680740356 val_loss:  0.2412656992673874
Epoch:  381 trn_loss:  0.19209453463554382 val_loss:  0.24103514850139618
Epoch:  382 trn_loss:  0.19185423851013184 val_loss:  0.24080201983451843
Epoch:  383 trn_loss:  0.19161449372768402 val_loss:  0.24056822061538696
Epoch:  384 trn_loss:  0.19137562811374664 val_loss:  0.24033890664577484
Epoch:  385 trn_loss:  0.19113805890083313 val_loss:  0.24011285603046417
Epoch:  386 trn_loss:  0.19090212881565094 val_loss:  0.23988765478134155
Epoch:  387 trn_loss:  0.19066746532917023 val_loss:  0.23966200649738312
Epoch:  388 trn_loss:  0.19043302536010742 val_loss:  0.2394344061613083
Epoch:  389 trn_loss:  0.1901978850364685 val_loss:  0.23920194804668427
Epoch:  390 trn_loss:  0.1899610310792923 val_loss:  0.23896577954292297
Epoch:  391 trn_loss:  0.18972285091876984 val_loss:  0.23873156309127808
Epoch:  392 trn_loss:  0.18948638439178467 val_loss:  0.2385076880455017
Epoch:  393 trn_loss:  0.1892559826374054 val_loss:  0.23829428851604462
Epoch:  394 trn_loss:  0.1890343278646469 val_loss:  0.2380877584218979
Epoch:  395 trn_loss:  0.18881294131278992 val_loss:  0.23789480328559875
Epoch:  396 trn_loss:  0.18858793377876282 val_loss:  0.23769930005073547
Epoch:  397 trn_loss:  0.18836161494255066 val_loss:  0.2374742329120636
Epoch:  398 trn_loss:  0.18813608586788177 val_loss:  0.23721234500408173
Epoch:  399 trn_loss:  0.1879124492406845 val_loss:  0.2369389533996582
Epoch:  400 trn_loss:  0.18769139051437378 val_loss:  0.23668262362480164
Epoch:  401 trn_loss:  0.18747246265411377 val_loss:  0.23645438253879547
Epoch:  402 trn_loss:  0.1872546523809433 val_loss:  0.23624952137470245
Epoch:  403 trn_loss:  0.18703687191009521 val_loss:  0.23605674505233765
Epoch:  404 trn_loss:  0.1868191361427307 val_loss:  0.23585844039916992
Epoch:  405 trn_loss:  0.1866016834974289 val_loss:  0.23564380407333374
Epoch:  406 trn_loss:  0.1863844096660614 val_loss:  0.2354169487953186
Epoch:  407 trn_loss:  0.18616719543933868 val_loss:  0.23519253730773926
Epoch:  408 trn_loss:  0.18595176935195923 val_loss:  0.23498374223709106
Epoch:  409 trn_loss:  0.18573901057243347 val_loss:  0.2347954660654068
Epoch:  410 trn_loss:  0.1855253428220749 val_loss:  0.23461133241653442
Epoch:  411 trn_loss:  0.18531186878681183 val_loss:  0.23439818620681763
Epoch:  412 trn_loss:  0.18509967625141144 val_loss:  0.23414193093776703
Epoch:  413 trn_loss:  0.18488802015781403 val_loss:  0.2338661402463913
Epoch:  414 trn_loss:  0.18467681109905243 val_loss:  0.23360903561115265
Epoch:  415 trn_loss:  0.1844668984413147 val_loss:  0.23338620364665985
Epoch:  416 trn_loss:  0.18425790965557098 val_loss:  0.23319271206855774
Epoch:  417 trn_loss:  0.18404950201511383 val_loss:  0.23300528526306152
Epoch:  418 trn_loss:  0.18384194374084473 val_loss:  0.23279762268066406
Epoch:  419 trn_loss:  0.18363498151302338 val_loss:  0.23256391286849976
Epoch:  420 trn_loss:  0.1834283173084259 val_loss:  0.23232361674308777
Epoch:  421 trn_loss:  0.18322235345840454 val_loss:  0.23210249841213226
Epoch:  422 trn_loss:  0.18301713466644287 val_loss:  0.2319049835205078
Epoch:  423 trn_loss:  0.1828123927116394 val_loss:  0.2317172884941101
Epoch:  424 trn_loss:  0.18260842561721802 val_loss:  0.23152095079421997
Epoch:  425 trn_loss:  0.18240493535995483 val_loss:  0.23130284249782562
Epoch:  426 trn_loss:  0.18220146000385284 val_loss:  0.23107072710990906
Epoch:  427 trn_loss:  0.18199846148490906 val_loss:  0.23084309697151184
Epoch:  428 trn_loss:  0.1817961037158966 val_loss:  0.2306332290172577
Epoch:  429 trn_loss:  0.18159405887126923 val_loss:  0.23043972253799438
Epoch:  430 trn_loss:  0.1813923716545105 val_loss:  0.23024484515190125
Epoch:  431 trn_loss:  0.18119114637374878 val_loss:  0.2300359606742859
Epoch:  432 trn_loss:  0.1809902936220169 val_loss:  0.2298157960176468
Epoch:  433 trn_loss:  0.1807897388935089 val_loss:  0.22959937155246735
Epoch:  434 trn_loss:  0.1805896759033203 val_loss:  0.22939623892307281
Epoch:  435 trn_loss:  0.18038998544216156 val_loss:  0.22920359671115875
Epoch:  436 trn_loss:  0.18019072711467743 val_loss:  0.22901242971420288
Epoch:  437 trn_loss:  0.1799919307231903 val_loss:  0.22881709039211273
Epoch:  438 trn_loss:  0.1797935962677002 val_loss:  0.22861284017562866
Epoch:  439 trn_loss:  0.17959582805633545 val_loss:  0.22840112447738647
Epoch:  440 trn_loss:  0.1793985366821289 val_loss:  0.22818748652935028
Epoch:  441 trn_loss:  0.1792016178369522 val_loss:  0.22797751426696777
Epoch:  442 trn_loss:  0.17900484800338745 val_loss:  0.22777128219604492
Epoch:  443 trn_loss:  0.17880797386169434 val_loss:  0.2275659739971161
Epoch:  444 trn_loss:  0.1786109060049057 val_loss:  0.22735942900180817
Epoch:  445 trn_loss:  0.1784133017063141 val_loss:  0.22714942693710327
Epoch:  446 trn_loss:  0.1782146394252777 val_loss:  0.2269381880760193
Epoch:  447 trn_loss:  0.17801454663276672 val_loss:  0.22673633694648743
Epoch:  448 trn_loss:  0.1778154969215393 val_loss:  0.22655323147773743
Epoch:  449 trn_loss:  0.17762120068073273 val_loss:  0.2263881415128708
Epoch:  450 trn_loss:  0.17742790281772614 val_loss:  0.22621634602546692
Epoch:  451 trn_loss:  0.17723171412944794 val_loss:  0.22600814700126648
Epoch:  452 trn_loss:  0.17703591287136078 val_loss:  0.22576133906841278
Epoch:  453 trn_loss:  0.17684142291545868 val_loss:  0.22550815343856812
Epoch:  454 trn_loss:  0.1766476035118103 val_loss:  0.22528120875358582
Epoch:  455 trn_loss:  0.17645536363124847 val_loss:  0.2250857949256897
Epoch:  456 trn_loss:  0.17626416683197021 val_loss:  0.22490385174751282
Epoch:  457 trn_loss:  0.17607294023036957 val_loss:  0.22471295297145844
Epoch:  458 trn_loss:  0.1758810132741928 val_loss:  0.22450202703475952
Epoch:  459 trn_loss:  0.1756894290447235 val_loss:  0.22427566349506378
Epoch:  460 trn_loss:  0.17549946904182434 val_loss:  0.22405284643173218
Epoch:  461 trn_loss:  0.17531032860279083 val_loss:  0.2238508015871048
Epoch:  462 trn_loss:  0.17512105405330658 val_loss:  0.22366595268249512
Epoch:  463 trn_loss:  0.17493221163749695 val_loss:  0.22347603738307953
Epoch:  464 trn_loss:  0.17474424839019775 val_loss:  0.22326765954494476
Epoch:  465 trn_loss:  0.1745568960905075 val_loss:  0.223037987947464
Epoch:  466 trn_loss:  0.17437073588371277 val_loss:  0.22281275689601898
Epoch:  467 trn_loss:  0.17418527603149414 val_loss:  0.2226175218820572
Epoch:  468 trn_loss:  0.17400020360946655 val_loss:  0.22244416177272797
Epoch:  469 trn_loss:  0.17381562292575836 val_loss:  0.22226789593696594
Epoch:  470 trn_loss:  0.17363184690475464 val_loss:  0.22207652032375336
Epoch:  471 trn_loss:  0.1734485924243927 val_loss:  0.2218659371137619
Epoch:  472 trn_loss:  0.1732657551765442 val_loss:  0.22164830565452576
Epoch:  473 trn_loss:  0.17308367788791656 val_loss:  0.2214433252811432
Epoch:  474 trn_loss:  0.17290203273296356 val_loss:  0.22125384211540222
Epoch:  475 trn_loss:  0.17272073030471802 val_loss:  0.22106535732746124
Epoch:  476 trn_loss:  0.17254012823104858 val_loss:  0.22086644172668457
Epoch:  477 trn_loss:  0.1723601073026657 val_loss:  0.22065678238868713
Epoch:  478 trn_loss:  0.1721806526184082 val_loss:  0.22044801712036133
Epoch:  479 trn_loss:  0.1720017045736313 val_loss:  0.22025033831596375
Epoch:  480 trn_loss:  0.17182311415672302 val_loss:  0.22006283700466156
Epoch:  481 trn_loss:  0.17164508998394012 val_loss:  0.21987499296665192
Epoch:  482 trn_loss:  0.1714676022529602 val_loss:  0.2196822315454483
Epoch:  483 trn_loss:  0.17129066586494446 val_loss:  0.21948617696762085
Epoch:  484 trn_loss:  0.17111413180828094 val_loss:  0.21928857266902924
Epoch:  485 trn_loss:  0.17093804478645325 val_loss:  0.21909268200397491
Epoch:  486 trn_loss:  0.17076237499713898 val_loss:  0.21890264749526978
Epoch:  487 trn_loss:  0.17058725655078888 val_loss:  0.2187121957540512
Epoch:  488 trn_loss:  0.17041254043579102 val_loss:  0.21851617097854614
Epoch:  489 trn_loss:  0.170238196849823 val_loss:  0.21832245588302612
Epoch:  490 trn_loss:  0.1700642853975296 val_loss:  0.21813243627548218
Epoch:  491 trn_loss:  0.16989077627658844 val_loss:  0.21794255077838898
Epoch:  492 trn_loss:  0.1697176694869995 val_loss:  0.21775567531585693
Epoch:  493 trn_loss:  0.1695450097322464 val_loss:  0.21756310760974884
Epoch:  494 trn_loss:  0.1693727970123291 val_loss:  0.21736563742160797
Epoch:  495 trn_loss:  0.16920095682144165 val_loss:  0.217173770070076
Epoch:  496 trn_loss:  0.16902948915958405 val_loss:  0.21698927879333496
Epoch:  497 trn_loss:  0.16885840892791748 val_loss:  0.21680626273155212
Epoch:  498 trn_loss:  0.16868776082992554 val_loss:  0.21662037074565887
Epoch:  499 trn_loss:  0.16851751506328583 val_loss:  0.21642708778381348
Training time = 150.8094642162323 seconds

Plotting: MSE Loss for Training and Validation

To assess how well the model has trained, Figure 2 shows the MSE loss for training (blue) and validation (orange) as a function of epoch.

In [11]:
plt.plot(np.arange(len(train_loss)), train_loss, color="blue")
plt.plot(np.arange(len(valdn_loss)), valdn_loss, color="orange")
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.show()

Figure 2: Training and Validation MSE loss (blue, orange) as a function of Epoch.


Step 4: Testing the Model

Now that the model has been trained, testing it is a computationally cheap procedure. As before, we select the data using DEMdata and load it with DataLoader. Using valtest_model, the DeepEM map is created via ${\texttt{output = model(img)}}$, and the MSE loss is calculated as during training.

In [12]:
dem_data_test = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='test')
dem_loader = DataLoader(dem_data_test, batch_size=1)

t0=time.time() #Timing how long it takes to predict the DEMs
dummy, test_loss, dem_pred, dem_in_test = valtest_model(dem_loader, criterion)
performance = "Number of DEM solutions per second = {0}".format((y_test.shape[2]*y_test.shape[3])/(time.time()-t0))

print(performance)
Number of DEM solutions per second = 3130834.8456553183

Plotting: AIA, Basis Pursuit, DeepEM

With the DeepEM map calculated, we can now compare the solutions obtained by Basis Pursuit and DeepEM. Figure 3 is similar to Figure 1, with an additional row corresponding to the DeepEM solutions. Figure 3 shows SDO/AIA images in 171 Å, 211 Å, and 94 Å (top), with the corresponding DEM bins from Basis Pursuit (chosen at the peak sensitivity of each of the SDO/AIA channels) shown below (middle). The bottom row shows the DeepEM solutions in the same bins as the Basis Pursuit solutions. DeepEM provides solutions that are similar to Basis Pursuit but, importantly, provides DEM solutions for every pixel.

In [13]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_in_test[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_in_test[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_in_test[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_pred[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_pred[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_pred[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)

Figure 3: Left to Right: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top), with the corresponding DEM bins from Basis Pursuit (chosen at the peak sensitivity of each of the SDO/AIA channels) shown below (middle). The bottom row shows the DeepEM solutions that correspond to the same bins as the Basis Pursuit solutions. DeepEM provides solutions that are similar to Basis Pursuit, but importantly, provides DEM solutions for every pixel.


Furthermore, as we have the original Basis Pursuit DEM solutions ("the ground truth"), we can compare the average DEM from Basis Pursuit to the average DEM from DeepEM, as they should be similar. Figure 4 shows the average Basis Pursuit DEM (black curve) and the DeepEM solution (light blue bars/dotted line).

In [14]:
def PlotTotalEM(em_unscaled, em_pred_unscaled, lgtaxis, status):
    mask = np.zeros([status.shape[0],status.shape[1]])
    mask[np.where(status == 0.0)] = 1.0
    nmask = np.sum(mask)
    
    EM_tru_sum = np.zeros([lgtaxis.size])
    EM_inv_sum = np.zeros([lgtaxis.size])
    
    for i in range(lgtaxis.size):
        EM_tru_sum[i] = np.sum(em_unscaled[0,i,:,:]*mask)/nmask
        EM_inv_sum[i] = np.sum(em_pred_unscaled[0,i,:,:]*mask)/nmask
        
    fig = plt.figure()
    plt.plot(lgtaxis, EM_tru_sum, linewidth=3, color="black")
    plt.plot(lgtaxis, EM_inv_sum, linewidth=3, color="lightblue", linestyle='--')
    plt.tick_params(axis='both', which='major')
    plt.tick_params(axis='both', which='minor')
    
    dlogT = lgtaxis[1]-lgtaxis[0]
    plt.bar(lgtaxis-0.5*dlogT, EM_inv_sum, dlogT, linewidth=2, color='lightblue')
    
    plt.xlim(lgtaxis[0]-0.5*dlogT, lgtaxis.max()+0.5*dlogT)
    plt.xticks(np.arange(np.min(lgtaxis), np.max(lgtaxis),2*dlogT))
    plt.ylim(1e24,1e27)
    plt.yscale('log')
    plt.xlabel('log$_{10}$T [K]')
    plt.ylabel('Mean Emission Measure [cm$^{-5}$]')
    plt.title('Basis Pursuit (curve) vs. DeepEM (bars)')
    
    plt.show()
    return EM_inv_sum, EM_tru_sum
In [15]:
em_unscaled = em_unscale(dem_in_test.detach().cpu().numpy())
em_pred_unscaled = em_unscale(dem_pred.detach().cpu().numpy())
status = np.zeros([512,512]) #Setting statuses to zero here, but could be provided
                   
EMinv, EMTru = PlotTotalEM(em_unscaled,em_pred_unscaled,lgtaxis,status)

Figure 4: Average Basis Pursuit DEM (black line) against the Average DeepEM solution (light blue bars/dotted line). It is clear that this simple implementation of DeepEM provides, on average, DEMs that are similar to Basis Pursuit (Cheung et al 2015).


Step 5: Synthesize SDO/AIA Observations

Finally, it is also of interest to reconstruct the SDO/AIA observations from both the Basis Pursuit and DeepEM solutions.

We can pose the reconstruction of the SDO/AIA observations from the DEM as a 1x1 2D convolution. We first define the weights as the temperature response functions of each channel, and set the biases to zero. By convolving the unscaled DEM at each pixel with the six filters (one for each SDO/AIA response function), we recover the SDO/AIA observations.

In [16]:
# We first load the AIA response functions:
cl = np.load('./DeepEM_Data/chianti_lines_AIA.npy')
In [17]:
# Use Conv2d to map every pixel's DEM (18x1x1) through the 6 response functions,
# returning the observed fluxes in each channel (6x1x1)
dem2aia = nn.Conv2d(18, 6, kernel_size=1).cuda()

chianti_lines_2 = torch.zeros(6,18,1,1).cuda()
biases = torch.zeros(6).cuda()

# set the weights to each of the SDO/AIA response functions and biases to zero
for i, p in enumerate(dem2aia.parameters()):
    if i == 0:
        p.data = Variable(torch.from_numpy(cl).type(torch.cuda.FloatTensor))
    else:
        p.data = biases 
In [18]:
AIA_out = img_scale(dem2aia(Variable(em_unscale(dem_in_test))).detach().cpu().numpy())
AIA_out_DeepEM = img_scale(dem2aia(Variable(em_unscale(dem_pred))).detach().cpu().numpy())
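A 1x1 convolution with zero bias is just the matrix formulation $\vec{g} = {\bf K}\vec{\xi}$ applied independently at every pixel. The following is a minimal NumPy sanity check of that equivalence (not part of the original notebook; the shapes mirror the 6 channels and 18 temperature bins above, but the arrays and the small 4x4 patch are illustrative):

```python
import numpy as np

# Illustrative stand-ins: 6 AIA channels, 18 temperature bins, a 4x4 patch
rng = np.random.default_rng(0)
K = rng.random((6, 18))      # response functions (the 1x1 conv weights)
xi = rng.random((18, 4, 4))  # unscaled DEM cube

# "1x1 convolution": contract the temperature axis at every pixel at once
g_conv = np.einsum('ct,thw->chw', K, xi)

# Explicit per-pixel matrix product g = K @ xi for comparison
g_loop = np.zeros((6, 4, 4))
for y in range(4):
    for x in range(4):
        g_loop[:, y, x] = K @ xi[:, y, x]

print(np.allclose(g_conv, g_loop))  # -> True
```

This is why setting the Conv2d weights to the response functions (and biases to zero) synthesizes the observations exactly.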

Plotting SDO/AIA Observations and Synthetic Observations

In [19]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(AIA_out[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'Basis Pursuit Synthesized 211 $\AA$', color="white", size='large')
ax[0].imshow(AIA_out[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'Basis Pursuit Synthesized 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'Basis Pursuit Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(AIA_out_DeepEM[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'DeepEM Synthesized 211 $\AA$', color="white", size='large')
ax[0].imshow(AIA_out_DeepEM[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'DeepEM Synthesized 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out_DeepEM[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'DeepEM Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    

Figure 5: Left to Right: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top) with the corresponding synthesised observations from Basis Pursuit (middle) and DeepEM (bottom). DeepEM provides synthetic observations that are similar to Basis Pursuit, with the addition of solutions where the Basis Pursuit solution was zero.


Discussion

This chapter has provided an example of how a 1x1 2D Convolutional Neural Network can be used to reduce the computational cost of DEM inversion. Future improvements to DeepEM could come in a few ways:

First, by synthesising observations from the DEM, the ability of the DEM to recover the original or supplementary data can be used as an additional term in the loss function. Furthermore, a number of additional data sets could be used to further constrain the DEMs:

  • Use SDO/AIA Data to correct the DEMs
  • Use MEGS-A EUV to correct the DEMs
  • Use Hard X-ray observations to correct the DEMs
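The first of these ideas, adding a data-fidelity term to the loss, can be sketched as the usual MSE against the Basis Pursuit DEM plus a term that folds the predicted DEM back through the response functions and compares with the observed intensities. A minimal NumPy sketch (the names combined_loss, K, and lam are illustrative, not from the notebook):

```python
import numpy as np

def combined_loss(dem_pred, dem_true, g_obs, K, lam=0.1):
    """Supervised DEM loss plus a synthesis (data-fidelity) term.

    dem_pred, dem_true : (18, H, W) unscaled DEM cubes
    g_obs              : (6, H, W) observed AIA intensities
    K                  : (6, 18) temperature response matrix
    lam                : weight of the synthesis term (hypothetical choice)
    """
    # Standard term: distance from the Basis Pursuit "ground truth"
    mse_dem = np.mean((dem_pred - dem_true) ** 2)
    # Synthesis term: g = K @ xi per pixel, compared with the observations
    g_syn = np.einsum('ct,thw->chw', K, dem_pred)
    mse_syn = np.mean((g_syn - g_obs) ** 2)
    return mse_dem + lam * mse_syn
```

If dem_pred equals dem_true and g_obs was itself synthesized from dem_true, both terms vanish and the loss is zero; any mismatch in either the DEM or the reconstructed intensities increases it.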

Appendix A: What has the CNN learned about our training set?

If we treat the training set as the test set, we can see how much the CNN has learned about the training data:

In [20]:
X_test = X_train 
y_test = y_train
In [21]:
dem_data_test = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='test')
dem_loader = DataLoader(dem_data_test, batch_size=1)

dummy, test_loss, dem_pred_trn, dem_in_test_trn = valtest_model(dem_loader, criterion)
In [22]:
AIA_out = img_scale(dem2aia(Variable(em_unscale(dem_in_test_trn))).detach().cpu().numpy())
AIA_out_DeepEM = img_scale(dem2aia(Variable(em_unscale(dem_pred_trn))).detach().cpu().numpy())
In [23]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_in_test_trn[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_in_test_trn[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_in_test_trn[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_pred_trn[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_pred_trn[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_pred_trn[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
In [24]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(AIA_out[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'Basis Pursuit Synthesized 211 $\AA$', color="white", size='large')
ax[0].imshow(AIA_out[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'Basis Pursuit Synthesized 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'Basis Pursuit Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(AIA_out_DeepEM[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'DeepEM Synthesized 211 $\AA$', color="white", size='large')
ax[0].imshow(AIA_out_DeepEM[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'DeepEM Synthesized 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out_DeepEM[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'DeepEM Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)